#----------------------------------------------
# -*- coding: utf-8 -*-                        #
# __author__: 'xiaojie'                        #
# CreateTime:                                  #
#       2019/4/12 10:38                        #
#                                              #
#   The heroes of the age rise from our midst; #
#   once in the jianghu, the years press on.   #
#   Empires and glory are but talk and laughter;#
#   life is no more than one long drunken dream.#
#----------------------------------------------

# Summary: this example is still a WGAN with the original WGAN loss; the only addition
# is the spectral normalization technique.
# Spectral normalization

# Reference: "Lipschitz constraints in deep learning: generalization and generative models":
# https://zhuanlan.zhihu.com/p/46924315
# https://github.com/bojone/gan/blob/master/keras/wgan_sn_celeba.py

# Keras implementation of spectral normalization, done by passing it as kernel_constraint.
# Note that before using this code you also have to patch the Keras source: modify the
# add_weight method of the Layer class in keras/engine/base_layer.py.
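# Background on why the patch is needed: stock Keras applies kernel_constraint only
# *after* each optimizer update, while spectral normalization has to act on the weight
# that is used in the forward pass, so the linked post changes Layer.add_weight to
# return the constrained tensor directly. A rough sketch of the idea (the exact lines
# are an assumption here and differ across Keras versions -- follow the post for the
# actual diff):
#
#     # inside Layer.add_weight, after the weight variable has been created:
#     if constraint is not None:
#         weight = constraint(weight)   # use the constrained tensor downstream
#
# With this change, every layer built with kernel_constraint=spectral_normalization
# computes its forward pass with the spectrally normalized kernel.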


import numpy as np
from scipy import misc  # note: misc.imread / misc.imresize require an older SciPy (< 1.2)
import glob
from keras.models import Model
from keras.layers import *
from keras import backend as K
from keras.optimizers import Adam
import os
from keras.utils import plot_model

if not os.path.exists('samples'):
    os.mkdir('samples')
if not os.path.exists('png'):   # plot_model below saves diagrams into ./png/
    os.mkdir('png')

imgs = glob.glob('../../images/img_align_celeba/*.jpg')
np.random.shuffle(imgs)


height,width = misc.imread(imgs[0]).shape[:2]
center_height = int((height-width)/2)
img_dim = 64
z_dim = 100

def imread(f):
    x = misc.imread(f)
    x = x[center_height:center_height + width, :]   # center-crop to a width x width square
    x = misc.imresize(x, (img_dim, img_dim))        # resize to img_dim x img_dim
    return x.astype(np.float32) / 255 * 2 - 1       # rescale to [-1, 1] for the tanh generator

def data_generator(batch_size = 32):
    X=[]
    while True:
        np.random.shuffle(imgs)
        for f in imgs:
            X.append(imread(f))
            if len(X) == batch_size:
                X = np.array(X)
                yield X
                X=[]
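# Illustrative usage (not called here): each yielded batch is a float32 array of shape
# (batch_size, img_dim, img_dim, 3) with values in [-1, 1], e.g.
#   batch = next(data_generator(batch_size=4))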

def spectral_norm(w, r=5):
    # Estimate the largest singular value of w with r steps of power iteration.
    w_shape = K.int_shape(w)
    in_dim = np.prod(w_shape[:-1]).astype(int)   # flatten all input dimensions
    out_dim = w_shape[-1]
    w = K.reshape(w, (in_dim, out_dim))
    u = K.ones((1, in_dim))
    for i in range(r):
        v = K.l2_normalize(K.dot(u, w))               # approximate right singular vector
        u = K.l2_normalize(K.dot(v, K.transpose(w)))  # approximate left singular vector
    return K.sum(K.dot(K.dot(u, w), K.transpose(v)))  # u W v^T ~= spectral norm of w

def spectral_normalization(w):
    # Constraint: divide the weight by its estimated spectral norm.
    return w / spectral_norm(w)
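# Optional sanity check (illustrative, not part of the original training script): compare
# the power-iteration estimate above with the exact largest singular value from NumPy.
# The estimate is a lower bound that tightens as r grows, so the two should roughly agree.
def _check_spectral_norm(shape=(5, 5, 3, 16), seed=0):
    rng = np.random.RandomState(seed)
    w_np = rng.randn(*shape).astype('float32')
    est = K.eval(spectral_norm(K.constant(w_np)))                            # power iteration
    exact = np.linalg.svd(w_np.reshape(-1, shape[-1]), compute_uv=False)[0]  # true sigma_max
    print('spectral norm: estimate %.4f vs exact %.4f' % (est, exact))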

# Discriminator (WGAN critic)
x_in = Input(shape=(img_dim,img_dim,3))
x=x_in

x=Conv2D(img_dim,
         (5,5),
         strides=(2,2),
         padding='same',
         kernel_constraint=spectral_normalization)(x)
x=LeakyReLU()(x)

for i in range(3):
    x=Conv2D(img_dim*2**(i+1),
             (5,5),
             strides=(2,2),
             padding='same',
             kernel_constraint=spectral_normalization)(x)
    x=BatchNormalization(gamma_constraint=spectral_normalization)(x)
    x=LeakyReLU()(x)

x=Flatten()(x)
x=Dense(1,use_bias=False,kernel_constraint=spectral_normalization)(x)
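# Note: the final Dense layer above has no activation; a WGAN critic outputs an
# unbounded score rather than a probability.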

d_model=Model(x_in,x)
plot_model(d_model,to_file='./png/w_gan_sn_d_model.png',show_shapes=True)

# Generator
z_in = Input(shape=(z_dim,))
z=z_in

z=Dense(4*4*img_dim*8)(z)
z=BatchNormalization()(z)
z=Activation('relu')(z)
z=Reshape((4,4,img_dim*8))(z)
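# The feature map grows 4x4 -> 8x8 -> 16x16 -> 32x32 in the loop below, and the final
# transposed convolution produces the 64x64x3 output.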

for i in range(3):
    z=Conv2DTranspose(img_dim*4//2**i,(5,5),strides=(2,2),padding='same')(z)
    z=BatchNormalization()(z)
    z=Activation('relu')(z)

z=Conv2DTranspose(3,(5,5),strides=(2,2),padding='same')(z)
z=Activation('tanh')(z)

g_model=Model(z_in,z)
plot_model(g_model,to_file='./png/w_gan_sn_g_model.png',show_shapes=True)

# Combined model (for training the discriminator)
x_in = Input(shape=(img_dim,img_dim,3))
z_in=Input(shape=(z_dim,))
g_model.trainable=False

x_fake=g_model(z_in)
x_real_score=d_model(x_in)
x_fake_score=d_model(x_fake)

d_train_model=Model([x_in,z_in],[x_real_score,x_fake_score])

d_loss=K.mean(x_fake_score-x_real_score)
d_train_model.add_loss(d_loss)
d_train_model.compile(optimizer=Adam(2e-4,0.5))
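# Wasserstein critic loss: minimize E[D(fake)] - E[D(real)]. The Lipschitz constraint
# comes from spectral normalization of the kernels rather than weight clipping or a
# gradient penalty.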

# Combined model (for training the generator)
g_model.trainable=True
d_model.trainable=False
x_fake_score=d_model(g_model(z_in))

g_train_model=Model(z_in,x_fake_score)
g_train_model.add_loss(K.mean(-x_fake_score))
g_train_model.compile(optimizer=Adam(2e-4,0.5))
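# Note on the freeze pattern: Keras captures the trainable flags at compile time, so
# d_train_model (compiled while g_model.trainable was False) only updates the critic,
# and g_train_model (compiled while d_model.trainable was False) only updates the
# generator, even though the flags are toggled in between.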


# Check the model structures
plot_model(d_train_model,to_file='./png/w_gan_sn_d_train_model.png',show_shapes=True)
plot_model(g_train_model,to_file='./png/w_gan_sn_g_train_model.png',show_shapes=True)


# Sampling function: tile an n x n grid of generated images and save it
def sample(path):
    n = 9
    figure = np.zeros((img_dim * n, img_dim * n, 3))
    for i in range(n):
        for j in range(n):
            z_sample = np.random.randn(1, z_dim)
            x_sample = g_model.predict(z_sample)
            digit = x_sample[0]
            figure[i * img_dim:(i + 1) * img_dim,
            j * img_dim:(j + 1) * img_dim] = digit
    figure = (figure + 1) / 2 * 255                 # map from [-1, 1] back to [0, 255]
    figure = np.round(figure).astype('uint8')
    misc.imsave(path, figure)


iter_per_sample=100
total_iter=1000000
batch_size=20
img_generator=data_generator(batch_size)

for i in range(total_iter):
    # Train the critic 5 times per generator update (the usual WGAN schedule).
    for j in range(5):
        z_sample = np.random.randn(batch_size, z_dim)
        d_loss2 = d_train_model.train_on_batch([next(img_generator), z_sample], None)
    # Then train the generator once.
    for j in range(1):
        z_sample = np.random.randn(batch_size, z_dim)
        g_loss = g_train_model.train_on_batch(z_sample, None)
    if i%10==0:
        print('iter: %s, d_loss: %s, g_loss: %s' % (i, d_loss2, g_loss))
    if i%iter_per_sample ==0:
        sample('samples/test_%s.png' % i)
        g_train_model.save_weights('./g_train_model.weights')

